from collections import deque
from functools import reduce
from itertools import groupby

import networkx as nx

class Port:
    """Plain record holding two integer identifiers, ``nid`` and ``pid``."""

    def __init__(self, nid: int, pid: int):
        # Both values are stored unchanged; no validation is performed.
        self.nid, self.pid = nid, pid

class State(object):
    """A state carrying a ``sid``, an ``nid`` and an ``rstate`` mapping.

    Equality and hashing are inherited from ``object`` (identity), so two
    ``State`` objects with the same ``sid``/``nid`` are distinct dict keys;
    only ``<`` looks at ``nid``.
    """

    def __init__(self, sid: int, nid: int, rstate=None):
        """Create a state.

        Args:
            sid: state identifier.
            nid: value used by ``__lt__``; ``DFAOpt.optimize`` also uses it to
                align states across DFAs.
            rstate: mapping stored as-is (not copied). ``DFAOpt.optimize``
                stores ``{dfa_index: component_state}`` here. When omitted,
                each instance gets its own fresh empty dict.
        """
        self.sid = sid
        self.nid = nid
        # A literal ``{}`` default would be one dict shared by every instance
        # created without ``rstate``; mutating one would mutate them all.
        self.rstate = {} if rstate is None else rstate

    def __lt__(self, rhs):
        """Order states by ``nid`` only."""
        return self.nid < rhs.nid

    def __repr__(self):
        return 'State(sid=%r, nid=%r)' % (self.sid, self.nid)

class Symbol(object):
    """A transition label carrying a ``pid`` and an auxiliary ``rindex``.

    Equality and hashing are inherited from ``object`` (identity): two
    ``Symbol(1)`` objects are different dict keys.
    """

    def __init__(self, pid: int, rindex=None):
        """Create a symbol.

        Args:
            pid: symbol identifier.
            rindex: stored as-is (not copied). The previous annotation said
                ``set[int]``, but ``DFAOpt.optimize`` passes a
                ``{dfa_index: state}`` dict here. When omitted, each instance
                gets its own fresh empty set.
        """
        self.pid = pid
        # A ``set()`` default would be one set shared by all instances.
        self.rindex = set() if rindex is None else rindex

    def __repr__(self):
        return 'Symbol(pid=%r)' % (self.pid,)

class DFA(object):
    """An automaton over caller-supplied state and symbol objects.

    ``transitions`` maps a source state to a ``{symbol: destination}`` dict.
    A reverse index ``rtransitions`` (``{destination: [(source, symbol), ...]}``)
    is built once in the constructor so that ``prevs`` is a dict lookup.
    """

    def __init__(self, states,
                 transitions,
                 initials,
                 acceptings):
        """Store the automaton and precompute counts and the reverse index.

        Args:
            states: collection of all states (only ``len`` is used here).
            transitions: ``{source: {symbol: destination}}``.
            initials: initial states, stored as given.
            acceptings: accepting states, stored as given.
        """
        self.states = states
        self.transitions = transitions
        self.initials = initials
        self.acceptings = acceptings
        # Misspelled attribute kept as an alias so existing readers still work.
        self.accpetings = acceptings

        self.n_states = len(self.states)

        # Group incoming edges by destination with a plain dict. Sorting by
        # destination and then ``groupby`` on it is unsafe: states compare by
        # ``nid`` only but group by identity, so distinct states sharing an
        # ``nid`` can interleave, split into several groups and overwrite
        # each other. (Source/symbol pairs are unique in a dict-of-dicts, so
        # counting edges here matches counting a de-duplicated edge set.)
        rtransitions = {}
        n_transitions = 0
        for src, edges in transitions.items():
            for sym, dst in edges.items():
                rtransitions.setdefault(dst, []).append((src, sym))
                n_transitions += 1
        self.n_transitions = n_transitions
        self.rtransitions = rtransitions

    def next(self, state, symbol):
        """Return the destination of *state* on *symbol*, or ``None``."""
        return self.transitions.get(state, {}).get(symbol, None)

    def nexts(self, state):
        """Return the ``{symbol: destination}`` dict of *state* (``{}`` if none)."""
        return self.transitions.get(state, {})

    def prevs(self, state):
        """Return the ``[(source, symbol), ...]`` edges entering *state*."""
        return self.rtransitions.get(state, [])

    def draw(self, filename):
        """Render the automaton with networkx and save the image to *filename*.

        Node names are ``"(nid, sid)"``; each edge gets a ``sym`` attribute
        holding its symbol's ``pid``. A fresh figure is used and closed
        afterwards so repeated calls neither overlay nor leak figures.
        """
        import matplotlib.pyplot as plt

        def label(state):
            return '(%s, %s)' % (state.nid, state.sid)

        g = nx.DiGraph()
        for s in self.states:
            g.add_node(label(s))
        for s in self.transitions:
            for sym, d in self.nexts(s).items():
                if d is None:
                    continue
                g.add_edge(label(s), label(d), sym=sym.pid)

        fig = plt.figure()
        try:
            nx.draw(g, ax=fig.gca())
            fig.savefig(filename)
        finally:
            plt.close(fig)

class DFAOpt(object):
    """Combine several DFAs into one automaton by merging aligned states.

    States of the input DFAs are aligned by their ``nid``. Each state of the
    result stores in ``rstate`` a ``{dfa_index: component_state}`` dict telling
    which input states it stands for.
    """

    def __init__(self):
        pass

    def optimize(self, DFA: list[DFA]):
        """Merge the automata in *DFA* into a single automaton.

        Initial states of all inputs that share an ``nid`` become one merged
        initial state. Merged states are then expanded breadth-first. A
        component state becomes ready only once every edge entering it in its
        own DFA (``prevs``) has been seen from an already-processed merged
        state. Component states that become ready in the same step, share an
        ``nid`` and have exactly the same set of incoming
        ``(merged_state, pid)`` edges are merged into one new state.

        A component state whose predecessors are never all processed (for
        example one whose predecessor is only reachable through it) is never
        placed.

        Args:
            DFA: the automata to merge. The name shadows the ``DFA`` class
                inside this method; it is kept for keyword compatibility.

        Returns:
            ``(initials, states, transitions)``: the merged initial states,
            every merged state in breadth-first order, and a dict mapping each
            merged state to ``{Symbol(pid, target_rstate): target_state}``.
            Every merged state receives a distinct ``sid`` (0, 1, 2, ...).
        """
        dfas = DFA

        # Merge initial states by nid: {nid: {dfa_index: initial_state}}.
        by_nid = {}
        for i, dfa in enumerate(dfas):
            for s in dfa.initials:
                by_nid.setdefault(s.nid, {})[i] = s

        next_sid = 0
        initials = []
        for nid in sorted(by_nid):
            # Distinct sids for every merged initial state (previously all 0).
            initials.append(State(next_sid, nid, by_nid[nid]))
            next_sid += 1

        states = []
        transitions = {}
        # Component state -> incoming (merged_state, pid) edges seen so far.
        pending = {}
        queue = deque(initials)
        while queue:
            cur = queue.popleft()
            states.append(cur)
            transitions[cur] = {}

            # Record every outgoing edge of every component of ``cur``.
            ready = []
            for i, comp in cur.rstate.items():
                for sym, dst in dfas[i].nexts(comp).items():
                    seen = pending.setdefault(dst, [])
                    seen.append((cur, sym.pid))
                    # All upstream states of dst (in its own DFA) are processed.
                    if len(seen) == len(dfas[i].prevs(dst)):
                        del pending[dst]
                        ready.append((dst.nid, frozenset(seen), i, dst))

            # Group by (nid, incoming-edge set) with a dict. Sets are only
            # partially ordered, so sorting by them and using groupby could
            # leave equal keys non-adjacent and split one merged state in two.
            groups = {}
            for nid, preds, i, dst in ready:
                groups.setdefault((nid, preds), {})[i] = dst

            for (nid, preds), comps in sorted(groups.items(),
                                              key=lambda kv: kv[0][0]):
                new = State(next_sid, nid, comps)
                next_sid += 1
                for src, pid in preds:
                    transitions[src][Symbol(pid, comps)] = new
                queue.append(new)
        return initials, states, transitions

if __name__ == '__main__':
    def _build_dfa(ids, edges, start, accept):
        """Build a DFA whose states use each id in *ids* as both sid and nid.

        *edges* is a sequence of ``(src_id, pid, dst_id)`` triples; each edge
        gets its own fresh ``Symbol(pid)``.
        """
        node = {k: State(k, k) for k in ids}
        table = {}
        for src, pid, dst in edges:
            table.setdefault(node[src], {})[Symbol(pid)] = node[dst]
        return DFA(
            states=set(node.values()),
            transitions=table,
            initials=[node[start]],
            acceptings=[node[accept]],
        )

    # Two identical "diamond" automata: 1 -(1)-> 2, 1 -(2)-> 3, {2, 3} -(1)-> 4 ...
    diamond = [(1, 1, 2), (1, 2, 3), (2, 1, 4), (3, 1, 4)]
    dfa1 = _build_dfa((1, 2, 3, 4), diamond, 1, 4)
    dfa2 = _build_dfa((1, 2, 3, 4), diamond, 1, 4)
    # ... and a straight chain 1 -(1)-> 2 -(1)-> 4.
    dfa3 = _build_dfa((1, 2, 4), [(1, 1, 2), (2, 1, 4)], 1, 4)

    initials, states, transitions = DFAOpt().optimize([dfa1, dfa2, dfa3])
    print(states)
    print(len(states))
